import asyncio
import hashlib
import json
import time
import uuid
from typing import Any

import httpx
import jwt
from jwcrypto import jwk
from jwt import PyJWK, PyJWKClient
from starlette.requests import Request
from starlette.responses import JSONResponse

from mindsdb.utilities import log

logger = log.getLogger(__name__)
# Scheme prefix of the ``Authorization`` header that carries the signed JWT
# (checked and stripped by the receiver before decoding the token).
AUTH_HEADER_PREFIX = "Bearer "


class PushNotificationAuth:
    """Common base for the sender and receiver sides of push-notification auth."""

    def _calculate_request_body_sha256(self, data: dict[str, Any]):
        """Return the hex-encoded SHA256 digest of ``data`` serialized as compact JSON.

        The signing agent and the verifying client must serialize the payload in
        exactly this way, otherwise their digests will not match. Keys keep their
        insertion order, non-ASCII characters are emitted as-is (then UTF-8
        encoded before hashing), and NaN/Infinity values raise ``ValueError``.
        """
        serialized = json.dumps(
            data,
            separators=(",", ":"),  # no whitespace between tokens
            indent=None,
            allow_nan=False,
            ensure_ascii=False,
        )
        payload_bytes = serialized.encode()
        return hashlib.sha256(payload_bytes).hexdigest()


class PushNotificationSenderAuth(PushNotificationAuth):
    """Signs outgoing push-notifications with a locally generated RSA key.

    Call ``generate_jwk()`` before sending. It creates the signing key and
    records its public half, which ``handle_jwks_endpoint`` serves to receivers.
    """

    def __init__(self):
        # Public halves (JWK dicts) of every key generated so far; served as a JWKS.
        self.public_keys = []
        # Active private signing key; populated by ``generate_jwk()``.
        self.private_key_jwk: PyJWK | None = None

    @staticmethod
    async def verify_push_notification_url(url: str) -> bool:
        """Check that ``url`` echoes back a random validation token.

        Sends ``GET url?validationToken=<uuid4>`` and expects the response body to
        be exactly that token.

        NOTE(review): ``url`` is presumably client-supplied; fetching arbitrary
        URLs from the server can be abused for SSRF — consider restricting hosts.

        Returns:
            True if the endpoint answered with the token. False on a mismatch,
            HTTP error status, or any other error; errors are logged, not raised.
        """
        async with httpx.AsyncClient(timeout=10) as client:
            try:
                validation_token = str(uuid.uuid4())
                response = await client.get(url, params={"validationToken": validation_token})
                response.raise_for_status()
                is_verified = response.text == validation_token

                logger.info(f"Verified push-notification URL: {url} => {is_verified}")
                return is_verified
            except Exception as e:
                # Best-effort check: any failure simply means "not verified".
                logger.warning(f"Error during verifying push-notification URL {url}: {e}")

        return False

    def generate_jwk(self):
        """Generate a fresh 2048-bit RSA signing key and publish its public half.

        The new key becomes the active signing key. Earlier public keys stay in
        ``public_keys``, so receivers can still look them up by ``kid``.
        """
        key = jwk.JWK.generate(kty="RSA", size=2048, kid=str(uuid.uuid4()), use="sig")
        self.public_keys.append(key.export_public(as_dict=True))
        self.private_key_jwk = PyJWK.from_json(key.export_private())

    def handle_jwks_endpoint(self, _request: Request):
        """Allow clients to fetch public keys (JWKS document: ``{"keys": [...]}``)."""
        return JSONResponse({"keys": self.public_keys})

    def _generate_jwt(self, data: dict[str, Any]):
        """Return an RS256 JWT binding the payload digest and the issue time.

        The token is signed with the private key, which lets the client verify
        the payload's integrity. The ``iat`` claim lets the receiver reject stale
        tokens, which limits replay attacks. The ``kid`` header tells the
        receiver which published key to verify against.

        Raises:
            RuntimeError: if ``generate_jwk()`` has not been called yet.
            ValueError: if ``data`` contains NaN/Infinity (from JSON serialization).
        """
        if self.private_key_jwk is None:
            raise RuntimeError("No signing key available; call generate_jwk() first")

        iat = int(time.time())

        return jwt.encode(
            {
                "iat": iat,
                "request_body_sha256": self._calculate_request_body_sha256(data),
            },
            # Pass the underlying cryptography key object, which all PyJWT 2.x
            # releases accept (accepting a PyJWK directly is a newer addition).
            key=self.private_key_jwk.key,
            headers={"kid": self.private_key_jwk.key_id},
            algorithm="RS256",
        )

    async def send_push_notification(self, url: str, data: dict[str, Any]):
        """POST ``data`` as JSON to ``url`` with a signed ``Authorization`` header.

        Delivery is best-effort: request/HTTP errors are logged, not raised.
        The token is built before the request is attempted, so errors from token
        generation (no key generated, non-serializable data) do propagate.
        """
        jwt_token = self._generate_jwt(data)
        headers = {"Authorization": f"{AUTH_HEADER_PREFIX}{jwt_token}"}
        async with httpx.AsyncClient(timeout=10) as client:
            try:
                response = await client.post(url, json=data, headers=headers)
                response.raise_for_status()
                logger.info(f"Push-notification sent for URL: {url}")
            except Exception as e:
                logger.warning(f"Error during sending push-notification for URL {url}: {e}")


class PushNotificationReceiverAuth(PushNotificationAuth):
    """Verifies push-notifications signed with keys published via a JWKS endpoint."""

    def __init__(self):
        # NOTE(review): not used within this class; kept for any external readers.
        self.public_keys_jwks = []
        # Set by ``load_jwks()``; resolves the token's ``kid`` to a public key.
        self.jwks_client: PyJWKClient | None = None

    async def load_jwks(self, jwks_url: str):
        """Point the verifier at the sender's JWKS endpoint.

        Keys are fetched by the client on first lookup, not here.
        """
        self.jwks_client = PyJWKClient(jwks_url)

    async def verify_push_notification(self, request: Request) -> bool:
        """Check a push-notification's signature, body digest and freshness.

        Returns:
            False if the ``Authorization`` header is missing or is not a Bearer
            token; True if every check passes.

        Raises:
            RuntimeError: if ``load_jwks()`` has not been called.
            ValueError: if the body digest does not match the signed token, or
                the token was issued more than 5 minutes ago.
            jwt.PyJWTError: (PyJWT subclasses) if the key lookup, signature or
                required-claim validation fails.
        """
        auth_header = request.headers.get("Authorization")
        if not auth_header or not auth_header.startswith(AUTH_HEADER_PREFIX):
            return False

        if self.jwks_client is None:
            raise RuntimeError("JWKS not loaded; call load_jwks() first")

        token = auth_header[len(AUTH_HEADER_PREFIX) :]
        # PyJWKClient fetches the JWKS with blocking I/O; run the lookup in a
        # worker thread so the event loop is not stalled while it downloads.
        signing_key = await asyncio.to_thread(self.jwks_client.get_signing_key_from_jwt, token)

        decode_token = jwt.decode(
            token,
            signing_key.key,
            options={"require": ["iat", "request_body_sha256"]},
            algorithms=["RS256"],
        )

        actual_body_sha256 = self._calculate_request_body_sha256(await request.json())
        if actual_body_sha256 != decode_token["request_body_sha256"]:
            # Payload signature does not match the digest in signed token.
            raise ValueError("Invalid request body")

        if time.time() - decode_token["iat"] > 60 * 5:
            # Do not allow push-notifications older than 5 minutes.
            # This is to prevent replay attack.
            raise ValueError("Token is expired")

        return True
